DALL·E Mini#

This is a simple way for generative artists to create DALL·E Mini artworks.

Note

Install the ekorpkit package first.

Set the logging level to WARNING if you don't want to see verbose logging.

If you run this notebook in Colab, set the hardware accelerator to GPU.

Check your jaxlib version and install the matching JAX build. For example: pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

!pip install ekorpkit[disco]
exit()

from ekorpkit import eKonf

eKonf.setLogger("INFO")
eKonf.set_cuda(device="4,5")
print("version:", eKonf.__version__)

is_notebook = eKonf.is_notebook()
is_colab = eKonf.is_colab()
print("is notebook?", is_notebook)
print("is colab?", is_colab)
if is_colab:
    eKonf.mount_google_drive(
        workspace="MyDrive/colab_workspace", project="disco-imagen"
    )

print("environment variables:")
eKonf.print(eKonf.env().dict())
INFO:ekorpkit.base:Setting cuda device to ['A100-SXM4-40GB', 'A100-SXM4-40GB']
INFO:ekorpkit.base:Google Colab not detected.
version: 0.1.36+1.g6577b1a.dirty
is notebook? True
is colab? False
environment variables:
{'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
 'CUDA_VISIBLE_DEVICES': '4, 5',
 'EKORPKIT_CONFIG_DIR': '/workspace/projects/ekorpkit-book/config',
 'EKORPKIT_DATA_DIR': None,
 'EKORPKIT_LOG_LEVEL': 'INFO',
 'EKORPKIT_PROJECT': 'ekorpkit-book',
 'EKORPKIT_WORKSPACE_ROOT': '/workspace',
 'KMP_DUPLICATE_LIB_OK': 'TRUE',
 'NUM_WORKERS': 230}
cfg = eKonf.compose("model/dalle_mini")
dalle = eKonf.instantiate(cfg)
INFO:ekorpkit.base:Loaded .env from /workspace/projects/ekorpkit-book/config/.env
INFO:ekorpkit.base:setting environment variable CACHED_PATH_CACHE_ROOT to /workspace/.cache/cached_path
INFO:ekorpkit.base:setting environment variable KMP_DUPLICATE_LIB_OK to TRUE
INFO:ekorpkit.base:Google Colab not detected.
INFO:ekorpkit.models.dalle.base:> downloading models...
INFO:ekorpkit.models.dalle.base:> loading modules...
INFO:ekorpkit.utils.lib:dalle_mini not imported, loading from /workspace/projects/ekorpkit-book/disco-imagen/libs/dalle-mini/src/dalle_mini as dalle_mini
INFO:ekorpkit.utils.lib:vqgan_jax.modeling_flax_vqgan not imported, loading from /workspace/projects/ekorpkit-book/disco-imagen/libs/vqgan-jax as vqgan_jax.modeling_flax_vqgan
INFO:ekorpkit.models.dalle.base:> loading models...
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
INFO:absl:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:ekorpkit.models.dalle.mini:Available devices: 6
INFO:ekorpkit.models.dalle.mini:Using 6 devices
INFO:ekorpkit.models.dalle.mini:Devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0), GpuDevice(id=4, process_index=0), GpuDevice(id=5, process_index=0)]
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:13.6
Some of the weights of DalleBart were initialized in float16 precision from the model checkpoint at /tmp/tmp1z3gmwn8:
[('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_2', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 
'scale'), ('model', 'encoder', 'embed_positions', 'embedding'), ('model', 'encoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'final_ln', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'scale'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:11.4
# text_prompts = "Members of the Federal Reserve Board are convened to cut the target interest rates, surrounded by doves. matte, artstation"
# batch_name = "dovish"

# text_prompts = 'At a special meeting, hawkish central bankers are poised to raise the target rates. Trending on artstation, matte'
# batch_name = "hawkish"

text_prompts = "Mt. Halla's beautiful flowers, artstation, matte"
batch_name = "halla"
dalle.imagine(
    text_prompts, 
    batch_name=batch_name, 
    n_samples=6, 
    show_collage=True,
)
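The commented-out prompts above can also be run as a batch. The sketch below is a hypothetical driver, assuming `dalle` is the model instantiated from the `model/dalle_mini` config earlier and that `imagine()` accepts these keyword arguments:

```python
# Hypothetical batch driver: call imagine() once per (prompt, batch_name) pair.
# `dalle` is assumed to be the instantiated model from the cells above.
prompt_batches = [
    ("Members of the Federal Reserve Board are convened to cut the target "
     "interest rates, surrounded by doves. matte, artstation", "dovish"),
    ("At a special meeting, hawkish central bankers are poised to raise the "
     "target rates. Trending on artstation, matte", "hawkish"),
    ("Mt. Halla's beautiful flowers, artstation, matte", "halla"),
]

def run_all(dalle, batches, n_samples=6):
    """Generate n_samples images for each prompt under its own batch name."""
    for text_prompt, batch_name in batches:
        dalle.imagine(
            text_prompt,
            batch_name=batch_name,
            n_samples=n_samples,
            show_collage=False,  # skip per-batch collages when looping
        )
```

Each batch keeps its own output directory, so the results stay separated by theme.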
INFO:ekorpkit.models.dalle.mini: >> elapsed time to diffuse: 0:00:47.700863
INFO:ekorpkit.models.dalle.base:Merging config with args: {}
6 samples generated to /workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla
text prompts: ["Mt. Halla's beautiful flowers, artstation, matte"]
sample image paths:
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(5)_0000.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(5)_0001.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(5)_0002.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(5)_0003.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(5)_0004.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(5)_0005.png
[image: collage of the 6 generated samples]
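The printed paths follow a predictable `batch_name(batch_num)_index.png` pattern. A small helper, inferred from the output above (not an ekorpkit API), can reconstruct the expected paths without scanning the directory:

```python
from pathlib import Path

def sample_paths(output_dir, batch_name, batch_num, num_samples):
    """Reconstruct expected sample file paths, e.g. halla(5)_0000.png.
    Naming pattern inferred from the printed output; not an ekorpkit API."""
    batch_dir = Path(output_dir) / batch_name
    return [
        batch_dir / f"{batch_name}({batch_num})_{i:04d}.png"
        for i in range(num_samples)
    ]

paths = sample_paths("outputs/dalle-mini", "halla", 5, 6)
```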

Collage the generated sample images#

dalle.collage(
    batch_name=batch_name,
    batch_num=4,
    ncols=3,
    num_images=6,
    show_filename=True,
    fontcolor="white",
)
INFO:ekorpkit.models.dalle.mini:Loading config from /workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(4)_settings.yaml
INFO:ekorpkit.models.dalle.mini:Merging config with diffuse defaults
INFO:ekorpkit.io.file:Processing [6] files from ['halla(4)_*.png']
[image: collage of batch halla(4) with filenames overlaid]
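With `num_images=6` and `ncols=3`, the collage comes out as a 2×3 grid; the row count is just a ceiling division. This is a sketch of the layout arithmetic, not ekorpkit internals:

```python
import math

def collage_grid(num_images, ncols):
    """Return the (rows, cols) needed to tile num_images into ncols columns."""
    nrows = math.ceil(num_images / ncols)
    return nrows, ncols
```

For example, `collage_grid(6, 3)` gives `(2, 3)`; a seventh image would add a third, partially filled row.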

Show the config#

dalle.show_config(batch_name=batch_name, batch_num=0)
INFO:ekorpkit.models.dalle.mini:Loading config from /workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(0)_settings.yaml
INFO:ekorpkit.models.dalle.mini:Merging config with diffuse defaults
{'batch_name': 'halla',
 'batch_num': 0,
 'cond_scale': 10.0,
 'gen_top_k': None,
 'gen_top_p': None,
 'n_samples': 6,
 'num_samples': 6,
 'resume_run': False,
 'run_to_resume': 'latest',
 'seed': 827882520,
 'set_seed': 'random_seed',
 'show_collage': True,
 'temperature': None,
 'text_prompts': ["Mt. Halla's beautiful flowers, photorealistic"]}
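Each run saves its resolved settings next to the images (the `halla(0)_settings.yaml` file loaded in the log above), including the `seed` (here `827882520`), which is what makes a batch reproducible. A hypothetical helper for locating that file, with the naming pattern inferred from the log line rather than from a documented API:

```python
from pathlib import Path

def settings_path(output_dir, batch_name, batch_num):
    """Path of the saved settings YAML for a given batch run.
    Pattern inferred from the log output above; not an ekorpkit API."""
    return Path(output_dir) / batch_name / f"{batch_name}({batch_num})_settings.yaml"
```

Loading this YAML and reusing its recorded seed and prompts should let you regenerate the same batch.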